import paddle
from src.model.ssbase import ConvNorm, ProjUint, SSRSplit, SSRMerge

class SSRBasic(paddle.nn.Layer):
    def __init__(self, in_dim, ch_dim, ou_dim, stride=1, splits=0, direct=True):
        """
        初始化重复基础
        params:
        - in_dim (int): 输入维度
        - ch_dim (int): 通道维度
        - ou_dim (int): 输出维度
        - stride (int): 滑动步长
        - splits (int): 分割次数
        - direct(bool): 直连标识
        """
        super().__init__()
        # 添加投影单元
        self.is_pass = direct                                            # 直连标志
        self.project = ProjUint(in_dim, ou_dim, stride)                  # 投影单元
        
        # 添加卷积单元
        self.convbn0 = SSRSplit(in_dim, ch_dim, stride, splits)          # 分割单元
        self.convbn1 = ConvNorm(ch_dim, ou_dim, kernel_size=1, stride=1) # 卷积单元，卷积核为1*1
        
    def forward(self, x):
        # 直连路径
        if self.is_pass:
            y = x
        else:
            y = self.project(x)
        
        # 卷积路径
        x = self.convbn0(x)
        x = self.convbn1(x)
        x = paddle.nn.functional.leaky_relu(x + y)
        
        return x

class SSRBlock(paddle.nn.Layer):
    def __init__(self, in_dim, ch_dim, ou_dim, reduce=0, repeat=0, splits=0):
        """
        初始化重复模块
        params:
        - in_dim(int): 输入维度
        - ch_dim(int): 通道维度
        - ou_dim(int): 输出维度
        - reduce(int): 缩小次数，缩小倍数为2^n
        - repeat(int): 重复次数，必须大于等于0
        - splits(int): 分割次数，分割尺度为2^n
        """
        super().__init__()
        # 设置输入参数
        assert reduce >= 0, '错误：缩小次数必须大于等于0！'
        assert repeat >= 0, '错误：重复次数必须大于等于0！'
        
        # 添加缩放项目
        self.block_item0 = SSRBasic(
            in_dim, ch_dim, ou_dim, stride=(2 if reduce > 0 else 1), splits=splits, direct=False
        )
        
        self.block_list1 = [] # 缩放项目列表
        for i in range(1, reduce): # 当缩小次数大于等于2时
            block_item1 = self.add_sublayer(
                'block_item1_' + str(i),
                SSRBasic(ou_dim, ch_dim, ou_dim, stride=2, splits=splits, direct=False)
            )
            self.block_list1.append(block_item1)
        
        # 添加重复项目
        self.block_list2 = [] # 重复项目列表
        for i in range(repeat):    # 当重复次数大于等于1时
            block_item2 = self.add_sublayer(
                'block_item2_' + str(i),
                SSRBasic(ou_dim, ch_dim, ou_dim, stride=1, splits=splits, direct=True)
            )
            self.block_list2.append(block_item2)
    
    def forward(self, x):
        # 缩放特征
        x = self.block_item0(x)

        for block_item1 in self.block_list1:
            x = block_item1(x) # 当缩小次数大于等于2时
            
        # 重复特征
        for block_item2 in self.block_list2:
            x = block_item2(x) # 当重复次数大于等于1时
            
        return x

class SSCBasic(paddle.nn.Layer):
    def __init__(self, in_dim, ch_dim, ou_dim, stride=1, splits=0, direct=True):
        """
        初始化循环基础
        params:
        - in_dim (int): 输入维度
        - ch_dim (int): 通道维度
        - ou_dim (int): 输出维度
        - stride (int): 滑动步长
        - splits (int): 分割次数
        - direct(bool): 直连标识
        """
        super().__init__()
        # 添加投影单元
        self.is_pass = direct                                            # 直连标志
        self.project = ProjUint(in_dim, ou_dim, stride)                  # 投影单元
        
        # 添加卷积单元
        self.convbn0 = SSRSplit(in_dim, ch_dim, stride, splits)          # 分割单元
        self.convbn1 = ConvNorm(ch_dim, ou_dim, kernel_size=3, stride=1) # 卷积单元，卷积核为3*3
        
    def forward(self, x):
        # 直连路径
        if self.is_pass:
            y = x
        else:
            y = self.project(x)
        
        # 卷积路径
        x = self.convbn0(x)
        x = self.convbn1(x)
        x = paddle.nn.functional.swish(x + y)
        
        return x    
    
class SSCBlock(paddle.nn.Layer):
    def __init__(self, in_dim, ch_dim, ou_dim, reduce=0, repeat=0, splits=0):
        """
        初始化循环模块
        params:
        - in_dim(int): 输入维度
        - ch_dim(int): 通道维度
        - ou_dim(int): 输出维度
        - reduce(int): 缩小次数，缩小倍数为2^n
        - repeat(int): 重复次数，必须大于等于0
        - splits(int): 分割次数，分割尺度为2^n
        """
        super().__init__()
        # 设置输入参数
        assert reduce >= 0, '错误：缩小次数必须大于等于0！'
        assert repeat >= 0, '错误：重复次数必须大于等于0！'
        
        self.reduce = reduce # 缩小次数
        self.repeat = repeat # 重复次数
        
        # 添加缩放项目
        self.block_item0 = SSCBasic(
            in_dim, ch_dim, ou_dim, stride=(2 if self.reduce > 0 else 1), splits=splits, direct=False
        )
        
        if self.reduce > 1: # 当缩小次数大于等于2时
            self.block_item1 = SSCBasic(ou_dim, ch_dim, ou_dim, stride=2, splits=splits, direct=False)
        
        # 添加重复项目
        if self.repeat > 0: # 当重复次数大于等于1时
            self.block_item2 = SSCBasic(ou_dim, ch_dim, ou_dim, stride=1, splits=splits, direct=True)
            
    def forward(self, x):
        # 缩放特征
        x = self.block_item0(x)

        for i in range(1, self.reduce):
            x = self.block_item1(x) # 当缩小次数大于等于2时
            
        # 重复特征
        for i in range(self.repeat):
            x = self.block_item2(x) # 当重复次数大于等于1时
            
        return x
    
class SSRNet(paddle.nn.Layer):
    def __init__(self, 
                 group_arch=[[3, 64, 256, 3, 8, 2], [128, 64, 256, 1, 4, 2], [128, 64, 256, 1, 2, 2]],
                 block_mode='ssr'):
        """
        初始化网络模型
        params:
        - group_arch(list): 特征块组结构：输入维度，通道维度，输出维度，缩小次数，重复次数，分割次数
        - block_mode (str): 特征模块模式：必须为'ssr'或'ssc'
        """
        super().__init__()
        # 设置模组变量
        assert block_mode in ['ssr','ssc'], "错误：模块模式必须为'ssr'或'ssc'"
        self.splits = len(group_arch) - 1 # 分割次数
        # self.dimensions = []            # 维度列表
        self.group_list = []              # 模组列表
        self.merge_list = []              # 合并列表
        
        # 添加模组列表
        for i, block_arch in enumerate(group_arch):
            # 添加维度列表
            if i < self.splits:
                ou_dim = block_arch[2]//2
                # self.dimensions.append(ou_dim)
            
            # 添加模组列表
            if block_mode == 'ssr':
                group_item = self.add_sublayer(
                    'group_' + str(i),
                    SSRBlock(
                        in_dim=block_arch[0],
                        ch_dim=block_arch[1],
                        ou_dim=block_arch[2],
                        reduce=block_arch[3],
                        repeat=block_arch[4],
                        splits=block_arch[5]
                    )
                )
            else:
                group_item = self.add_sublayer(
                    'group_' + str(i),
                    SSCBlock(
                        in_dim=block_arch[0],
                        ch_dim=block_arch[1],
                        ou_dim=block_arch[2],
                        reduce=block_arch[3],
                        repeat=block_arch[4],
                        splits=block_arch[5]
                    )
                )
            self.group_list.append(group_item)
            
            # 添加合并列表
            if i > 0:
                merge_item = self.add_sublayer(
                    'merge_' + str(self.splits - i),
                    SSRMerge(in_dim=block_arch[2], ou_dim=ou_dim, expand=2**block_arch[3])
                )
                self.merge_list.insert(0, merge_item)
            
    def forward(self, x):
        # 提取特征
        x_list = []  # 特征列表
        for i, group_item in enumerate(self.group_list):
            if i < self.splits:
                x = group_item(x)
                # x_item, x = paddle.split(x, num_or_sections=[-1, self.dimensions[i]], axis=1)
                x_item, x = paddle.split(x, num_or_sections=2, axis=1)# 使用该接口时，通道维度为2^n
                x_list.insert(0, x_item)
            else:
                x = group_item(x)
        
        # 合并特征
        c_list = [x] # 输出列表
        for i, merge_item in enumerate(self.merge_list):
            x = merge_item(x_list[i], x)
            c_list.append(x)
            
        return c_list